D2GRs2 源码阅读3-DataSet

├── data
│   ├── dataset.py           # 构建序列 dataset 时间逆序
│   ├── eval.py              # 评估指标,并记录到TensorBoard
│   ├── item_features.py     # 定义`ItemFeatures`数据类
│   ├── preprocessor.py      # 预处理两种数据集(下载、处理等)
│   └── reco_dataset.py      # 获取推荐(训练、评估)数据集
├── trainer
│    └── data_loader.py      # 数据加载器、支持分布式训练
├── train.py                 # 训练脚本
# train.py
dataset = get_reco_dataset(
	dataset_name=dataset_name, # ml-1m
	max_sequence_length=max_sequence_length, # 200
	chronological=True, # 按时间排序
	positional_sampling_ratio=positional_sampling_ratio,                           # 1 按位置采样
)

train_data_sampler, train_data_loader = create_data_loader(
	dataset.train_dataset,
	batch_size=local_batch_size,
	world_size=world_size,
	rank=rank,
	shuffle=True,
	drop_last=world_size > 1,
)

@dataclass装饰器,这个类能够存储一个数字,拥有比大小的功能,很大程度上减少了代码量,很方便。除了上面的整型外,还可以使用其他的类型,包括自己定义的数据类型。深度学习pytorch之dataclass

from dataclasses import dataclass
@dataclass
class RecoDataset:
	max_sequence_length: int
	num_unique_items: int
	max_item_id: int
	all_item_ids: List[int]
	train_dataset: torch.utils.data.Dataset
	eval_dataset: torch.utils.data.Dataset

Dataset是PyTorch提供的一个抽象类,我们可以继承这个类并重写__getitem____len__方法,从而创建自己的数据集。__getitem__方法用于获取单个数据样本,__len__方法则返回数据集的大小。

from torch.utils.data import Dataset
import os

class CustomDataset(Dataset):
	def __init__(self, data_dir, transform=None):
		self.data_dir = data_dir
		self.images = os.listdir(data_dir)

	def __len__(self):
		return len(self.images)

	def __getitem__(self, idx):
		img_path = os.path.join(self.data_dir, self.images[idx])
		image = Image.open(img_path)
		label = idx # 这里为了简化,我们直接用索引作为标签
		return image, label

DataLoader是PyTorch提供的一个数据加载器,它可以从Dataset中读取数据,并以批次的形式提供给模型进行训练。DataLoader的主要参数包括:

from torch.utils.data import DataLoader

# 假设我们已经创建了一个CustomDataset对象
dataset = CustomDataset(data_dir='./data', transform=transform)

# 创建一个DataLoader对象
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# 在训练循环中使用DataLoader
for epoch in range(num_epochs):
	for batch_idx, (data, target) in enumerate (data_loader):
		# 在这里进行模型的训练操作
		pass

Tensorflow模型的格式 - chease - 博客园

PyTorch1.12 亮点一览 | DataPipe + TorchArrow 新的数据加载与处理范式-CSDN博客

读写OSS数据_人工智能平台 PAI(PAI)-阿里云帮助中心